LLM 推断指南

借助 LLM Inference API,您可以完全在设备端运行大语言模型 (LLM),并使用这些模型执行各种各样的任务,例如生成文本、以自然语言形式检索信息以及总结文档。该任务内置了对多个文本到文本大型语言模型的支持,因此您可以将最新的设备端生成式 AI 模型应用于您的应用和产品。

试试吧!

该任务内置了对各种 LLM 的支持。LiteRT 社区页面上托管的模型采用 MediaPipe 兼容格式,无需任何额外的转换或编译步骤。

您可以使用 AI Edge Torch 将 PyTorch 模型导出为多签名 LiteRT (tflite) 模型,这些模型会与分词器参数捆绑在一起以创建任务软件包。使用 AI Edge Torch 转换的模型与 LLM Inference API 兼容,并且可以在 CPU 后端上运行,因此适用于 Android 和 iOS 应用。

开始使用

如需开始使用此任务,请按照适用于目标平台的以下任一实现指南操作。以下平台专用指南将引导您完成此任务的基本实现,其中包含使用可用模型和建议的配置选项的代码示例:

任务详情

本部分介绍了此任务的功能、输入、输出和配置选项。

功能

LLM Inference API 包含以下主要功能:

  1. 文本到文本生成 - 根据输入的文本提示生成文本。
  2. LLM 选择 - 应用多个模型,根据您的特定用例量身定制应用。您还可以对模型进行重新训练,并应用自定义权重。
  3. LoRA 支持 - 使用 LoRA 模型扩展和自定义 LLM 功能,方法是基于您的所有数据集进行训练,或从开源社区获取准备好的预构建 LoRA 模型(与使用 AI Edge Torch Generative API 转换的模型不兼容)。
任务输入 任务输出
LLM Inference API 接受以下输入:
  • 文本提示(例如问题、电子邮件主题、要总结的文档)
LLM Inference API 会输出以下结果:
  • 根据输入提示生成的文本(例如,问题的答案、电子邮件草稿、文档摘要)

配置选项

此任务具有以下配置选项:

选项名称 说明 值范围 默认值
modelPath 模型在项目目录中的存储路径。 路径 不适用
maxTokens 模型处理的词元(输入词元 + 输出词元)数量上限。 整数 512
topK 模型在生成过程中每个步骤考虑的令牌数。 将预测限制为前 k 个概率最高的 token。 整数 40
temperature 生成过程中引入的随机性程度。温度越高,生成的文本就越具创造性;温度越低,生成的文本就越具可预测性。 浮点数 0.8
randomSeed 文本生成期间使用的随机种子。 整数 0
loraPath 设备本地 LoRA 模型的绝对路径。注意:此功能仅适用于 GPU 型号。 路径 不适用
resultListener 设置结果监听器以异步接收结果。 仅在使用异步生成方法时适用。 不适用 不适用
errorListener 设置可选的错误监听器。 不适用 不适用

模型

LLM Inference API 支持许多文本到文本大语言模型,包括对经过优化以在浏览器和移动设备上运行的多个模型的内置支持。这些轻量级模型可用于完全在设备端运行推理。

在初始化 LLM Inference API 之前,请下载模型并将该文件存储在项目目录中。您可以使用 LiteRT 社区 HuggingFace 代码库中的预转换模型,也可以使用 AI Edge Torch 生成式转换器将模型转换为与 MediaPipe 兼容的格式。

如果您还没有可与 LLM 推理 API 搭配使用的 LLM,请先使用以下某个模型。

Gemma-3 1B

Gemma-3 1B 是 Gemma 系列先进轻量级开放模型中的最新模型,采用了与 Gemini 模型相同的研究成果和技术。该模型包含 10 亿个参数和开放权重。1B 变体是 Gemma 系列中最轻的模型,非常适合许多设备端用例。

下载 Gemma-3 1B

HuggingFace 中的 Gemma-3 1B 模型采用 .task 格式,可与适用于 Android 和 Web 应用的 LLM 推理 API 搭配使用。

使用 LLM Inference API 运行 Gemma-3 1B 时,请相应地配置以下选项:

  • preferredBackend:使用此选项可在 CPUGPU 后端之间进行选择。此选项仅适用于 Android 设备。
  • supportedLoraRanks:无法将 LLM Inference API 配置为支持使用 Gemma-3 1B 模型进行低秩自适应 (LoRA)。请勿使用 supportedLoraRanksloraRanks 选项。
  • maxTokensmaxTokens 的值必须与内置在模型中的上下文大小一致。这也称为键值 (KV) 缓存或上下文长度。
  • numResponses:始终必须为 1。此选项仅适用于网页版。

在 Web 应用中运行 Gemma-3 1B 时,初始化可能会导致当前线程出现长时间阻塞。请尽可能始终从工作器线程中运行模型。

Gemma-2 2B

Gemma-2 2B 是 Gemma-2 的 2B 变体,适用于所有平台。

下载 Gemma-2 2B

该模型包含 20 亿个参数和开放权重。Gemma-2 2B 以同类模型中先进的推理能力而闻名。

PyTorch 模型转换

您可以使用 AI Edge Torch Generative API 将 PyTorch 生成式模型转换为与 MediaPipe 兼容的格式。您可以使用此 API 将 PyTorch 模型转换为多签名 LiteRT (TensorFlow Lite) 模型。如需详细了解如何映射和导出模型,请访问 AI Edge Torch 的 GitHub 页面

使用 AI Edge Torch Generative API 转换 PyTorch 模型涉及以下步骤:

  1. 下载 PyTorch 模型检查点。
  2. 使用 AI Edge Torch Generative API 编写、转换模型,并将其量化为与 MediaPipe 兼容的文件格式 (.tflite)。
  3. 使用 tflite 文件和模型分词器创建任务软件包 (.task)。

Torch 生成式转换器仅适用于 CPU,并且需要 Linux 机器具有至少 64 GB 的 RAM。

如需创建任务软件包,请使用捆绑脚本创建任务软件包。捆绑过程会将映射的模型与其他元数据(例如分词器参数)来运行端到端推理。

模型捆绑流程需要 MediaPipe PyPI 软件包。转换脚本在 0.10.14 之后的所有 MediaPipe 软件包中均可用。

使用以下命令安装并导入依赖项:

$ python3 -m pip install mediapipe

使用 genai.bundler 库捆绑模型:

import mediapipe as mp
from mediapipe.tasks.python.genai import bundler

config = bundler.BundleConfig(
    tflite_model=TFLITE_MODEL,
    tokenizer_model=TOKENIZER_MODEL,
    start_token=START_TOKEN,
    stop_tokens=STOP_TOKENS,
    output_filename=OUTPUT_FILENAME,
    enable_bytes_to_unicode_mapping=ENABLE_BYTES_TO_UNICODE_MAPPING,
)
bundler.create_bundle(config)
参数 说明 可接受的值
tflite_model AI Edge 导出的 TFLite 模型的路径。 路径
tokenizer_model SentencePiece 词解析器模型的路径。 路径
start_token 模型专用开始令牌。所提供的词解析器模型中必须包含起始令牌。 STRING
stop_tokens 模型专用停止令牌。所提供的词解析器模型中必须包含停止标记。 LIST[STRING]
output_filename 输出任务软件包文件的名称。 路径

LoRA 自定义

Mediapipe LLM Inference API 可配置为支持大语言模型的低秩自适应 (LoRA)。利用微调后的 LoRA 模型,开发者可以通过经济高效的训练流程自定义 LLM 的行为。

LLM Inference API 的 LoRA 支持适用于 GPU 后端的所有 Gemma 变体和 Phi-2 模型,LoRA 权重仅适用于注意力层。此初始实现将作为实验性 API 用于未来开发,我们计划在即将发布的更新中支持更多模型和各种类型的层。

准备 LoRA 模型

按照 HuggingFace 上的说明,使用支持的模型类型(Gemma 或 Phi-2)在您自己的数据集上训练经过微调的 LoRA 模型。Gemma-2 2BGemma 2BPhi-2 模型均以 safetensors 格式在 HuggingFace 上提供。由于 LLM Inference API 仅支持注意力层上的 LoRA,因此在创建 LoraConfig 时,请仅指定注意力层,如下所示:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

如需进行测试,您可以使用 HuggingFace 上提供的适用于 LLM 推理 API 的公开可用的微调 LoRA 模型。例如,对于 Gemma-2B,使用 monsterapi/gemma-2b-lora-maths-orca-200k;对于 Phi-2,使用 lole25/phi-2-sft-ultrachat-lora

使用准备好的训练数据集进行训练并保存模型后,您会获得一个包含经过微调的 LoRA 模型权重的 adapter_model.safetensors 文件。safetensors 文件是模型转换中使用的 LoRA 检查点。

下一步,您需要使用 MediaPipe Python 软件包将模型权重转换为 TensorFlow Lite Flatbuffer。ConversionConfig 应指定基本模型选项以及其他 LoRA 选项。请注意,由于该 API 仅支持使用 GPU 进行 LoRA 推理,因此后端必须设置为 'gpu'

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

转换器将输出两个 TFLite FlatBuffer 文件,一个用于基准模型,另一个用于 LoRA 模型。

LoRA 模型推理

Web、Android 和 iOS LLM 推理 API 已更新为支持 LoRA 模型推理。

Android 在初始化期间支持静态 LoRA。如需加载 LoRA 模型,用户需要指定 LoRA 模型路径以及基础 LLM。

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

如需使用 LoRA 运行 LLM 推理,请使用与基准模型相同的 generateResponse()generateResponseAsync() 方法。